# Deep kernel learning (DKL) experiment, aligned with the Function Encoder setup
import matplotlib.pyplot as plt
import tqdm
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from sklearn.metrics.pairwise import polynomial_kernel

from torch.utils.data import DataLoader
from my_datasets.polynomial import PolynomialDataset
from function_encoder.model.mlp import MultiHeadedMLP

import time
import sys
import os
# Append the parent of this file's directory to the module search path.
# NOTE(review): this runs after the `my_datasets` / `function_encoder` imports
# earlier in the file, so it cannot help resolve them — confirm whether it was
# meant to come before those imports.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# --- device ---
# Preference order: CUDA first, then Apple MPS, otherwise the CPU.  The MPS
# check is only evaluated when CUDA is unavailable.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Seed the global torch RNG.
torch.manual_seed(42)


# Load dataset
# Batches drawn from this loader are unpacked further down as
# (X, y, example_X, example_y).
dataset = PolynomialDataset(n_points=100, n_example_points=50)
dataloader = DataLoader(dataset, batch_size=100)
dataloader_iter = iter(dataloader)

# Feature network whose outputs are fed to the kernel.  Every batch is moved
# to `device` before it reaches this network, so the network has to live on
# the same device; leaving it on the CPU fails with a device mismatch as soon
# as CUDA or MPS is selected.
theta = MultiHeadedMLP(layer_sizes=[1, 32, 1], num_heads=4).to(device)


class DeepLinearKernel(nn.Module):
    """Dot-product kernel on embeddings, carrying a learnable scalar ``log_lam``.

    ``log_lam`` is registered as a parameter so it is optimised with the rest
    of the module, but ``forward`` itself does not use it; code that needs the
    regularisation strength reads ``log_lam`` directly.
    """

    def __init__(self, log_lam0=0.0):
        super().__init__()
        initial_value = torch.tensor(float(log_lam0))
        self.log_lam = nn.Parameter(initial_value)

    def forward(self, zi, zj):
        """Return the matrix of inner products between rows of ``zi`` and ``zj``."""
        gram = torch.matmul(zi, zj.T)
        return gram


# Shared kernel instance; log_lam starts at -3, i.e. lam = exp(-3) initially.
kernel = DeepLinearKernel(log_lam0=-3)
kernel = kernel.to(device)


def dkl_batch_loss(X, y, example_X, example_y):
    """Mean per-task MSE of ridge-regularised kernel predictions.

    For each task in the batch, the module-level ``theta`` embeds the example
    inputs and the query inputs, the module-level ``kernel`` turns those
    embeddings into Gram matrices, and the linear system
    ``(K_SS + lam * I) alpha = example_y`` with ``lam = exp(kernel.log_lam)``
    is solved to predict the query targets as ``K_allS @ alpha``.

    Args:
        X: query inputs; unpacked as ``[B, P, _]`` and viewed as ``[B*P, 1]``.
        y: query targets, indexed per task along the first axis.
        example_X: example inputs; unpacked as ``[B, Ps, _]`` and viewed as
            ``[B*Ps, 1]``.
        example_y: example targets, indexed per task along the first axis.

    Returns:
        The MSE between predictions and ``y`` for each task, averaged over the
        ``B`` tasks.
    """
    n_tasks, n_query, _ = X.shape
    _, n_support, _ = example_X.shape
    ridge = kernel.log_lam.exp()

    # One forward pass per point set; the results are split back per task.
    support_feats = theta(example_X.view(n_tasks * n_support, 1))
    support_feats = support_feats.view(n_tasks, n_support, -1)   # [B, Ps, D]
    query_feats = theta(X.view(n_tasks * n_query, 1))
    query_feats = query_feats.view(n_tasks, n_query, -1)         # [B, P, D]

    eye_dtype = X.dtype

    def task_loss(b):
        """MSE on the query points of task ``b``."""
        feats_s = support_feats[b]
        feats_q = query_feats[b]
        targets_s = example_y[b]
        targets_q = y[b]

        gram_ss = kernel(feats_s, feats_s)       # [Ps, Ps]
        gram_qs = kernel(feats_q, feats_s)       # [P, Ps]

        identity = torch.eye(n_support, device=gram_ss.device, dtype=eye_dtype)
        coeffs = torch.linalg.solve(gram_ss + ridge * identity, targets_s)

        return F.mse_loss(gram_qs @ coeffs, targets_q)

    # Accumulate task by task, starting from 0.0, then average over tasks.
    return sum((task_loss(b) for b in range(n_tasks)), 0.0) / n_tasks


# --- training ---
# One optimizer step per "epoch"; each step consumes a single batch.
num_epochs = 500
lr = 1e-1
opt = torch.optim.Adam([
    {"params": theta.parameters()},
    {"params": kernel.parameters(), "weight_decay": 0.0}
], lr=lr)

start = time.perf_counter()
with tqdm.tqdm(range(num_epochs)) as tqdm_bar:
    for epoch in tqdm_bar:
        # A loader that yields fewer than `num_epochs` batches would otherwise
        # stop training with an uncaught StopIteration, so start a fresh pass
        # when it runs out.  A loader that never ends is unaffected.
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            batch = next(dataloader_iter)

        X, y, example_X, example_y = batch
        X = X.to(device)
        y = y.to(device)
        example_X = example_X.to(device)
        example_y = example_y.to(device)

        opt.zero_grad(set_to_none=True)
        loss = dkl_batch_loss(X, y, example_X, example_y)
        loss.backward()
        opt.step()
        tqdm_bar.set_postfix({"loss": f"{loss:.2e}"})
end = time.perf_counter()
print(f"Wall time training: {end - start:.6f} s")


# --- evaluation ---
# Predict a single task with the trained model, compare its Gram matrix on the
# example points with the one produced by a saved Function Encoder basis, and
# write the figures and CSV dumps.
theta.eval()
kernel.eval()
with torch.no_grad():
    torch.manual_seed(123)
    val_loader = DataLoader(dataset, batch_size=1)
    X, y, example_X, example_y = next(iter(val_loader))

    X = X.to(device)
    y = y.to(device)
    example_X = example_X.to(device)
    example_y = example_y.to(device)

    # Sort the query points along the point axis and reorder the targets the
    # same way, so the line plot at the end connects points in increasing x.
    idx = torch.argsort(X, dim=1, descending=False)
    X_sorted = torch.gather(X, dim=1, index=idx)
    y_sorted = torch.gather(y, dim=1, index=idx)

    B, P, _ = X.shape
    _, Ps, _ = example_X.shape
    start = time.perf_counter()

    Z_S = theta(example_X.view(B*Ps, 1)).view(B, Ps, -1)  # [1,Ps,D]
    Z_all = theta(X_sorted.view(B*P, 1)).view(B, P, -1)     # [1,P,D]

    zS = Z_S[0]         # [Ps, D]
    zA = Z_all[0]       # [P, D]
    yS = example_y[0]   # [Ps, 1]
    yA = y_sorted[0]    # [P, 1]

    K_SS = kernel(zS, zS)             # [Ps, Ps]
    K_allS = kernel(zA, zS)             # [P, Ps]
    lam = kernel.log_lam.exp()

    # Same ridge-regularised solve as in training: (K_SS + lam I) alpha = yS.
    A = K_SS + lam * torch.eye(Ps, device=K_SS.device, dtype=K_SS.dtype)
    alpha = torch.linalg.solve(A, yS)
    y_hat_all = K_allS @ alpha          # [P,1]

    end = time.perf_counter()
    print(f"Wall time prediction: {end - start:.6f} s")

    mse_val = F.mse_loss(y_hat_all, yA)
    print("Validation MSE:", mse_val.item())

    # Second feature network, rebuilt from the checkpoint's stored constructor
    # arguments ("arch") and weights ("state").
    ckpt = torch.load("fe_basis_ckpt.pt", map_location=device)
    theta2 = MultiHeadedMLP(**ckpt["arch"]).to(device)
    theta2.load_state_dict(ckpt["state"])
    theta2.eval()

    Z_S2 = theta2(example_X.view(B*Ps, 1)).view(B, Ps, -1)
    zS2 = Z_S2[0]       # [Ps, D]
    K_SS2 = kernel(zS2, zS2)

    def _center(K):
        """Double-centre a square matrix: H K H with H = I - (1/n) 1 1^T."""
        n = K.size(0)
        H = torch.eye(n, device=K.device, dtype=K.dtype) - \
            torch.ones(n, n, device=K.device, dtype=K.dtype)/n
        return H @ K @ H

    def cka(K1, K2):
        """Frobenius inner product of the centred matrices over the product of
        their Frobenius norms, returned as a Python float."""
        K1c, K2c = _center(K1), _center(K2)
        num = (K1c * K2c).sum()
        den = torch.linalg.norm(K1c, 'fro') * torch.linalg.norm(K2c, 'fro')
        return (num / den).item()

    # K_SS is the DKL Gram matrix and K_SS2 the Function Encoder one; the
    # score is symmetric in its arguments, so the label order is harmless.
    print("CKA(K_fe, K_dkl) =", cka(K_SS, K_SS2))

    X_np = X_sorted.squeeze(0).cpu().numpy()
    y_np = y_sorted.squeeze(0).cpu().numpy()
    yhat_np = y_hat_all.squeeze(0).cpu().numpy()
    exX_np = example_X.squeeze(0).cpu().numpy()
    exy_np = example_y.squeeze(0).cpu().numpy()

    # Reference Gram matrix on the same example inputs.
    k_poly = polynomial_kernel(exX_np, exX_np, degree=3)

    # to numpy
    k1 = K_SS.detach().float().cpu().numpy()
    k2 = K_SS2.detach().float().cpu().numpy()

    # shared scale for the two learned Grams
    vmin = float(min(k1.min(), k2.min()))
    vmax = float(max(k1.max(), k2.max()))

    fig, axs = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)

    axs[0].imshow(k1, vmin=vmin, vmax=vmax)
    axs[0].set_title("Linear Deep Kernel (K_DKL)")

    im1 = axs[1].imshow(k2, vmin=vmin, vmax=vmax)
    axs[1].set_title("Function Encoder (K_FE)")

    # The polynomial kernel panel uses its own colour scale.
    axs[2].imshow(k_poly)
    axs[2].set_title("Polynomial Kernel (K_poly)")

    for panel in axs:
        panel.set_xticks([])
        panel.set_yticks([])

    # one colorbar for the first two panels, which share a scale
    fig.colorbar(im1, ax=[axs[0], axs[1]], shrink=0.8)

    plt.savefig("grams_side_by_side_with_diff.png",
                dpi=150, bbox_inches="tight")
    plt.show()

    np.savetxt("gram_lin_kernel.csv",
               K_SS.detach().cpu().numpy(), delimiter=",")
    np.savetxt("gram_fe.csv",
               K_SS2.detach().cpu().numpy(), delimiter=",")
    np.savetxt("X_support.csv",
               example_X.squeeze(0).detach().cpu().numpy(), delimiter=",")
    np.savetxt("y_support.csv",
               example_y.squeeze(0).detach().cpu().numpy(), delimiter=",")

    fig, ax = plt.subplots()
    ax.plot(X_np, y_np, label="True")
    ax.plot(X_np, yhat_np, label="Predicted")
    ax.scatter(exX_np, exy_np, label="Support (example)", color="red")
    ax.legend()
    # Save before showing: after show() returns the figure is no longer the
    # current one, and a pyplot-level savefig would write an empty canvas.
    fig.savefig("polynomial_dkl.png")
    plt.show()
